{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2025 Arm Limited and/or its affiliates.\n",
    "#\n",
    "# This source code is licensed under the BSD-style license found in the\n",
    "# LICENSE file in the root directory of this source tree."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TOSA delegate flow example\n",
    "\n",
    "This guide walks through the complete process of running a module on Arm TOSA using ExecuTorch, with a focus on exploring TOSA lowering.\n",
    "This workflow is intended for contributors and developers who want to validate and experiment with lowering models to TOSA, not for production deployment.\n",
    "Note that the compilation flow and the passes applied vary by target, so this flow does not necessarily produce TOSA flatbuffers and PTE files that are optimal for (or even compatible with) any particular target.\n",
    "If something is not working for you, please raise a GitHub issue and tag Arm.\n",
    "\n",
    "Before you begin:\n",
    "1. In a clean virtual environment with a compatible Python version, install ExecuTorch using `./install_executorch.sh`.\n",
    "2. Install the Arm TOSA dependencies using `examples/arm/setup.sh --disable-ethos-u-deps`.\n",
    "\n",
    "Run all commands from the root `executorch` folder.\n",
    "\n",
    "*Some cells in this notebook produce long output logs. To make them more manageable, use the 'Customizing Notebook Layout' settings to enable 'Output:scrolling' and set an 'Output:Text Line Limit'.*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AOT Flow\n",
    "\n",
    "The ahead-of-time (AOT) flow starts by creating the PyTorch module and exporting it. Exporting converts the Python code in the module into a graph structure. The result is still runnable Python code, which you can display by printing the `graph_module` of the exported program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "print(torch.__version__)\n",
    "\n",
    "class Add(torch.nn.Module):\n",
    "    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n",
    "        return x + y\n",
    "\n",
    "example_inputs = (torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1))\n",
    "\n",
    "model = Add()\n",
    "model = model.eval()\n",
    "exported_program = torch.export.export(model, example_inputs)\n",
    "graph_module = exported_program.graph_module\n",
    "\n",
    "_ = graph_module.print_readable()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The TOSA backend supports both INT and FP targets\n",
    "\n",
    "To lower `graph_module` for FP targets with the TOSA backend, no quantization step is needed. We create a compile spec for an FP target and re-export the unchanged graph. The default FP lowering pipeline is applied later, during `to_edge_transform_and_lower`.\n",
    "\n",
    "FP lowering can be customized for different subgraphs; the sequence shown here is the recommended workflow for TOSA. Because the model stays in floating-point precision, no calibration with example inputs is required.\n",
    "\n",
    "If you print the module again, you will see that the nodes remain in FP form without any quantize/dequantize wrappers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec\n",
    "\n",
    "target = \"TOSA-1.0+FP\"\n",
    "base_name = \"tosa_simple_example\"\n",
    "cwd_dir = Path.cwd()\n",
    "\n",
    "# Create a compilation spec describing the FP target\n",
    "# and dump intermediate artifacts (here, TOSA flatbuffers) to the given directory\n",
    "compile_spec = TosaCompileSpec(target).dump_intermediate_artifacts_to(str(cwd_dir / base_name))\n",
    "\n",
    "# No quantization is applied for FP, so the graph is unchanged\n",
    "_ = graph_module.print_readable()\n",
    "\n",
    "# Create a new exported program from the (unquantized) graph_module\n",
    "lowered_exported_program = torch.export.export(graph_module, example_inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, to lower `graph_module` for INT targets with the TOSA backend, we first quantize it with `TOSAQuantizer`. Run this cell instead of the FP cell above: whichever cell runs last defines the `compile_spec` and `lowered_exported_program` used in the rest of the notebook.\n",
    "\n",
    "Quantization can be performed in various ways and tailored to different subgraphs; the sequence shown here represents the recommended workflow for TOSA.\n",
    "\n",
    "This step also requires calibrating the module with representative inputs; for simplicity, this example calibrates with `example_inputs` only.\n",
    "\n",
    "If you print the module again, you’ll see that each node is now wrapped in quantize/dequantize nodes that embed the calculated quantization parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec\n",
    "from executorch.backends.arm.quantizer import (\n",
    "    TOSAQuantizer,\n",
    "    get_symmetric_quantization_config,\n",
    ")\n",
    "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n",
    "from pathlib import Path\n",
    "\n",
    "target = \"TOSA-1.0+INT\"\n",
    "base_name = \"tosa_simple_example\"\n",
    "cwd_dir = Path.cwd()\n",
    "\n",
    "# Create a compilation spec describing the INT target; it is also used to configure the quantizer\n",
    "# and to dump intermediate artifacts (here, TOSA flatbuffers) to the given directory\n",
    "compile_spec = TosaCompileSpec(target).dump_intermediate_artifacts_to(str(cwd_dir / base_name))\n",
    "\n",
    "# Create and configure the quantizer to use a symmetric quantization config globally on all nodes\n",
    "quantizer = TOSAQuantizer(compile_spec)\n",
    "operator_config = get_symmetric_quantization_config()\n",
    "quantizer.set_global(operator_config)\n",
    "\n",
    "# Post-training quantization (PTQ)\n",
    "quantized_graph_module = prepare_pt2e(graph_module, quantizer)\n",
    "quantized_graph_module(*example_inputs)  # Calibrate the graph module with the example inputs\n",
    "quantized_graph_module = convert_pt2e(quantized_graph_module)\n",
    "\n",
    "_ = quantized_graph_module.print_readable()\n",
    "\n",
    "# Create a new exported program using the quantized_graph_module\n",
    "lowered_exported_program = torch.export.export(quantized_graph_module, example_inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lowering to the TOSA backend happens in four steps:\n",
    "\n",
    "1. **Lowering to the Core ATen operator set**: Transform the module to use a subset of operators suited to edge devices.\n",
    "2. **Partitioning**: Find subgraphs that are supported for running on TOSA.\n",
    "3. **Lowering to a TOSA-compatible operator set**: Transform the TOSA subgraph(s) so they are compatible with the TOSA operator set.\n",
    "4. **Serialization to TOSA**: Compile the graph module into a TOSA graph.\n",
    "\n",
    "Step 4 also prints a network summary for each processed subgraph.\n",
    "\n",
    "All of this happens behind the scenes in `to_edge_transform_and_lower`. For the INT target, printing the resulting graph module shows that the graph is left with two quantization nodes for `x` and `y` feeding an `executorch_call_delegate` node, followed by a dequantization node. For the FP target, there are no quantization nodes around the delegate call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from executorch.backends.arm.tosa.partitioner import TOSAPartitioner\n",
    "from executorch.exir import (\n",
    "    EdgeCompileConfig,\n",
    "    ExecutorchBackendConfig,\n",
    "    to_edge_transform_and_lower,\n",
    ")\n",
    "from executorch.extension.export_util.utils import save_pte_program\n",
    "\n",
    "# Create partitioner from compile spec\n",
    "partitioner = TOSAPartitioner(compile_spec)\n",
    "\n",
    "# Lower the exported program to the TOSA backend\n",
    "edge_program_manager = to_edge_transform_and_lower(\n",
    "    lowered_exported_program,\n",
    "    partitioner=[partitioner],\n",
    "    compile_config=EdgeCompileConfig(\n",
    "        _check_ir_validity=False,\n",
    "    ),\n",
    ")\n",
    "\n",
    "# Convert the edge program to an ExecuTorch program\n",
    "executorch_program_manager = edge_program_manager.to_executorch(\n",
    "    config=ExecutorchBackendConfig(extract_delegate_segments=False)\n",
    ")\n",
    "\n",
    "_ = executorch_program_manager.exported_program().graph_module.print_readable()\n",
    "\n",
    "# Save the .pte file in the same directory as the dumped TOSA artifacts\n",
    "pte_name = base_name + \".pte\"\n",
    "pte_path = cwd_dir / base_name / pte_name\n",
    "pte_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "save_pte_program(executorch_program_manager, str(pte_path))\n",
    "assert pte_path.exists(), \"Build failed; no .pte file found\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use the TOSA reference model to verify the TOSA graph\n",
    "\n",
    "After the AOT compilation flow is done, the lowered TOSA graph can be verified with the TOSA reference model tool."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importing the reference model fails early if it is not installed\n",
    "import tosa_reference_model  # noqa: F401\n",
    "from executorch.backends.arm.test.runner_utils import TosaReferenceModelDispatch\n",
    "\n",
    "# Run the TOSA graph through the reference model using the sample inputs\n",
    "with TosaReferenceModelDispatch():\n",
    "    output = executorch_program_manager.exported_program().module()(*example_inputs)\n",
    "\n",
    "print(output)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
